import torch
import numpy as np
import pickle
import os
import torchvision
import random
cpath = os.path.dirname(__file__)
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from opacus import PrivacyEngine

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import Resize, InterpolationMode


class CNN(nn.Module):
    """Two-conv-layer image classifier.

    Layout: conv(5x5) -> ReLU -> maxpool(2) -> conv(5x5) -> ReLU -> maxpool(2)
    -> flatten -> linear(32) -> ReLU -> linear(num_classes).

    Args:
        in_channels: Number of input image channels (defaults to 1, as for MNIST).
        num_classes: Number of output logits (defaults to 10).
        input_size: Side length of the square input images. It determines the
            width of ``fc1``; the default of 28 gives the original 512 features.

    Raises:
        ValueError: If ``input_size`` is too small to survive both conv/pool
            stages with at least a 1x1 feature map.
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=28):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, 5)
        self.conv2 = nn.Conv2d(16, 32, 5)
        # Each unpadded 5x5 conv shrinks the side by 4 and each 2x2 max-pool
        # halves it (rounding down): 28 -> 24 -> 12 -> 8 -> 4, so the default
        # input yields 32 channels * 4 * 4 = 512 flattened features.
        side = ((input_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for this network"
            )
        self.fc1 = nn.Linear(32 * side * side, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        ``x`` is expected to be (batch, in_channels, input_size, input_size).
        """
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        # Collapse (channels, H, W) into one feature axis per sample.
        out = torch.flatten(out, 1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)

# --- Data pipeline -----------------------------------------------------------
# Each 28x28 digit gets `num_padding` pixels of border on every side
# (28 + 2*14 = 56 px), and the result is then bicubically resized back to
# 28x28. The digit therefore shrinks into the centre of the frame.
num_padding = 14
batch_size = 128

transform = transforms.Compose([
    transforms.Pad(num_padding),
    Resize(size=(28, 28), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])

# Build the train split first, then the test split; both go through the same transform.
trainset, testset = (
    torchvision.datasets.MNIST(
        root='./data', train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# shuffle=False makes both loaders iterate in dataset order.
trainloader, testloader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (trainset, testset)
)

# Grab the first training batch for inspection.
batch_iter = iter(trainloader)
images, labels = next(batch_iter)
del batch_iter


# Preview one sample from the first training batch.
index = 0
# Fraction of the padded side length (per axis) that is padding:
# 2 * num_padding border pixels around the 28-pixel MNIST digit.
print('Padding rate', 2*num_padding/(28+2*num_padding))
plt.figure(figsize=(4, 4))
# ToTensor yields a channel-first (1, H, W) tensor for these grayscale images.
# Drop the singleton channel and render explicitly in grayscale rather than
# passing an (H, W, 1) array that imshow may reject or colour-map.
plt.imshow(images[index].squeeze(0), cmap='gray')
# .item() turns the 0-d label tensor into a plain int for list indexing.
plt.title(f"Label: {trainset.classes[labels[index].item()]}")
plt.axis('off')
plt.show()